#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 30 17:17:54 2023

@author: jcfq2
"""

import _pickle as cPickle
import scipy.ndimage            #For the .zoom function
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage, misc 
from skimage.transform import rescale, resize, downscale_local_mean



def cart2pol(x, y):
    """Convert Cartesian coordinates to polar form.

    Parameters
    ----------
    x, y : array_like
        Cartesian components (scalars or arrays; NumPy broadcasting applies).

    Returns
    -------
    tuple
        ``(rho, phi)``: the radial distance ``sqrt(x**2 + y**2)`` and the
        angle in radians given by ``np.arctan2(y, x)``.
    """
    radius = np.sqrt(x ** 2 + y ** 2)
    return radius, np.arctan2(y, x)

def pol2cart(rho, phi):
    """Convert polar coordinates back to Cartesian form.

    Parameters
    ----------
    rho : array_like
        Radial distance.
    phi : array_like
        Angle in radians.

    Returns
    -------
    tuple
        ``(x, y)`` with ``x = rho * cos(phi)`` and ``y = rho * sin(phi)``.
    """
    return rho * np.cos(phi), rho * np.sin(phi)


def phase_shift(vel0_180, vel90_270, phase):
    """Rotate the (vel90_270, vel0_180) vector field by ``phase`` degrees.

    The two maps are treated as Cartesian components
    ``(x, y) = (vel90_270, vel0_180)``. They are converted to polar form,
    the angle is advanced by ``phase`` degrees, and the result is converted
    back.

    Returns
    -------
    array_like
        The rotated y component, i.e. the new ``vel0_180``. The rotated x
        component is discarded.
    """
    magnitude, theta = cart2pol(vel90_270, vel0_180)
    _, shifted_y = pol2cart(magnitude, theta + np.radians(phase))
    return shifted_y


#%% Load in the saved data files
# Bin 4 Gaussian fits: one pickle per phase bin. Each bin variable is named
# after the centre of the angular range in its filename
# (315-45 -> 0, 45-135 -> 90, 135-225 -> 180, 225-315 -> 270).
filedir = 'Gaussian_fits/magnetic_phase/10-pixel_smooth/'

filename1='northern_bin1_315_45_order19.pickle'
filename2='northern_bin2_45_135_order19.pickle'
filename3='northern_bin3_135_225_order19.pickle'
filename4='northern_bin4_225_315_order19.pickle'

# NOTE(review): unpickling can execute arbitrary code -- only load files
# produced by this project's own fitting pipeline.
loaded_bins = []
for filename in (filename1, filename2, filename3, filename4):
    filepath = filedir + filename
    with open(filepath, 'rb') as opened_file:
        loaded_bins.append(cPickle.load(opened_file))
bin0, bin90, bin180, bin270 = loaded_bins


# Constants for turning fitted line positions into velocities. c is the
# speed of light in km/s (the maps are later labelled in km/s); R is
# presumably the spectral resolving power -- TODO confirm.
R=75000
c=300000

# Each pickled bin is indexed [quantity][row, col, channel]. Per the variable
# names, quantity 0 is the fitted intensity and 1 the fitted line position.
# Channel 0 is the value and channel 1 its error (only bin0's error is kept).
int0, pos0 = bin0[0][:, :, 0], bin0[1][:, :, 0]
int90, pos90 = bin90[0][:, :, 0], bin90[1][:, :, 0]
int180, pos180 = bin180[0][:, :, 0], bin180[1][:, :, 0]
int270, pos270 = bin270[0][:, :, 0], bin270[1][:, :, 0]
int_err, pos_err = bin0[0][:, :, 1], bin0[1][:, :, 1]

# Quick look at the bin0 intensity map.
plt.figure()
plt.imshow(int0, vmin=0, vmax=0.002)

# Line position -> line-of-sight velocity: v = -c * pos / R.
vel0 = pos0 / R * -c
vel90 = pos90 / R * -c
vel180 = pos180 / R * -c
vel270 = pos270 / R * -c

# Differences between opposite bins.
vel0_180 = vel0 - vel180
vel90_270 = vel90 - vel270

# Zero (rather than clip) any difference outside [-4, 4], in place.
for diff_map in (vel0_180, vel90_270):
    diff_map[(diff_map > 4) | (diff_map < -4)] = 0

# Symmetrise: average each map with its left-right mirror (axis=1), then
# average that with its top-bottom mirror (axis=0).
mirror_vel0_180 = (vel0_180 + np.flip(vel0_180, axis=1)) / 2
mirror_vel90_270 = (vel90_270 + np.flip(vel90_270, axis=1)) / 2

mirrormirror_vel0_180 = (mirror_vel0_180 + np.flip(mirror_vel0_180, axis=0)) / 2
mirrormirror_vel90_270 = (mirror_vel90_270 + np.flip(mirror_vel90_270, axis=0)) / 2

velocity_mirror='TrueTrue'

# Select which (possibly symmetrised) maps feed the rest of the analysis.
# mirror_str is appended to every output filename; any other flag value
# keeps the raw difference maps and an empty suffix.
_mirror_choices = {
    'True': ("_mirror", mirror_vel0_180, mirror_vel90_270),
    'TrueTrue': ("_mirrormirror", mirrormirror_vel0_180, mirrormirror_vel90_270),
}
mirror_str = ""
if velocity_mirror in _mirror_choices:
    mirror_str, vel0_180, vel90_270 = _mirror_choices[velocity_mirror]

# Preview the selected difference maps.
for preview in (vel0_180, vel90_270):
    plt.figure()
    plt.imshow(preview, vmin=-1.5, vmax=1.5, cmap='seismic_r', origin='lower')






# %% produce an average phase for a range of shift values

# Squash the y axis by 0.74486 (x untouched). Then keep an 88-column window
# (columns 100 +/- 44) and drop the first row.
# NOTE(review): the 88x88 accumulators and +/-44 crops below assume this
# window is also 88 rows tall -- confirm if the input size changes.
rs_vel0_180 = rescale(vel0_180, (0.74486, 1), anti_aliasing=True)
rs_vel90_270 = rescale(vel90_270, (0.74486, 1), anti_aliasing=True)

s_vel0_180 = rs_vel0_180[1:, 56:144]
s_vel90_270 = rs_vel90_270[1:, 56:144]

blank_phase = np.zeros([88, 88])

# Diagnostic sweep over phases 1..358 degrees. Each step plots the
# phase-shifted map (top) and the running mean so far (bottom). Disabled by
# default because it opens one figure per phase.
full_rotation = 'False'
if full_rotation == 'True':
    for phase in range(1, 359):
        rvel0_180 = ndimage.rotate(s_vel0_180, phase, mode='constant')
        rvel90_270 = ndimage.rotate(s_vel90_270, phase - 90, mode='constant')

        # The rotation enlarges the array, so crop the central 88x88 back
        # out. Both crops are centred on rvel0_180.
        mid_row = int(rvel0_180.shape[0] / 2)
        mid_col = int(rvel0_180.shape[1] / 2)
        centre = (slice(mid_row - 44, mid_row + 44),
                  slice(mid_col - 44, mid_col + 44))

        newvel_phase = phase_shift(rvel0_180[centre], rvel90_270[centre], phase)
        blank_phase = blank_phase + newvel_phase

        fig, (ax1, ax2) = plt.subplots(2)
        ax1.imshow(newvel_phase, vmin=-5, vmax=5)
        # `phase` maps have been summed so far, so this is the running mean.
        ax2.imshow(blank_phase / phase, vmin=-5, vmax=5)





# %%  set up a loop to save out all the data from every night

# For every observing night, average the phase-shifted velocity map over
# that night's phase range. Save the result at the original map
# resolution as finalphasevel_<day><mirror_str>.npy.
full_calculation = 'True'
if full_calculation == 'True':

    # Start/end phase (degrees) per night. A range that passes through 360
    # is written with end < start (e.g. 271 -> 21) and is unwrapped below.
    date=np.array(['25jul','01aug','07aug','12aug','14aug','20aug','25aug'])
    startphases=np.array([104,271, 54, 60,225, 10, 20])
    endphases=  np.array([193, 21,110,140,304, 62, 89])

    for day, start_phase, end_phase in zip(date, startphases, endphases):
        if end_phase < start_phase:
            end_phase = end_phase + 360
        # Inclusive range: one map per whole degree.
        n_steps = end_phase - start_phase + 1

        blank_phase = np.zeros([88, 88])
        for i in range(n_steps):
            phase = start_phase + i

            rvel0_180 = ndimage.rotate(s_vel0_180, phase, mode='constant')
            rvel90_270 = ndimage.rotate(s_vel90_270, phase - 90, mode='constant')

            # ndimage.rotate enlarges the array (reshape=True by default), so
            # cut the central 88x88 back out. Both crops are centred on
            # rvel0_180; for a square input, rotations by phase and phase-90
            # share one output shape.
            coi_x = rvel0_180[0, :].size / 2
            coi_y = rvel0_180[:, 0].size / 2
            rows = slice(int(coi_y) - 44, int(coi_y) + 44)
            cols = slice(int(coi_x) - 44, int(coi_x) + 44)
            rrvel0_180 = rvel0_180[rows, cols]
            rrvel90_270 = rvel90_270[rows, cols]

            newvel_phase = phase_shift(rrvel0_180, rrvel90_270, phase)
            blank_phase = blank_phase + newvel_phase

        # Mean over all of the night's phases.
        blank_phase = blank_phase / n_steps

        # Re-inject the 88x88 result into the rescaled frame it was cut from.
        rs_finalphasevel = np.zeros_like(rs_vel0_180)
        rs_finalphasevel[1:, 100-44:100+44] = blank_phase

        # Undo the earlier (0.74486, 1) rescale by resizing straight to the
        # original map shape. Scaling by an inverted factor relies on rounding
        # to recover the row count (e.g. 1/0.74486 gives 119 rows instead of
        # 120). Targeting the exact shape keeps the saved map pixel-aligned
        # with the input maps for any input size.
        finalphasevel = resize(rs_finalphasevel, vel0_180.shape, anti_aliasing=True)

        np.save('finalphasevel_' + day + mirror_str + '.npy', finalphasevel)
    




# %%  plot an example
# Recompute the phase-averaged map for one night (date[dayno]) and plot three
# panels side by side: the map at the night's first phase, the map at its
# last phase, and the mean over all phases. The figure is saved as
# asd_phase_calc<mirror_str>.pdf.


plot_for_paper = 'True'
if plot_for_paper == 'True':

    
    # Same per-night table as the full-calculation cell: start/end phase in
    # degrees, where end < start means the range wraps through 360.
    date=np.array(['25jul','01aug','07aug','12aug','14aug','20aug','25aug'])
    startphases=np.array([104,271, 54, 60,225, 10, 20])
    endphases=  np.array([193, 21,110,140,304, 62, 89])
    dayno=6  # index into date: 6 -> '25aug'
    
    # Single pass (dd = 0) that selects the night at index dayno.
    for dd in range(1):
        day=date[dd+dayno]
        start_phase = startphases[dd+dayno]
        end_phase = endphases[dd+dayno]

        blank_phase=np.zeros([88,88])
        
        # Unwrap a phase range that passes through 360 degrees.
            
           
        if end_phase < start_phase: end_phase = end_phase+360
        
        
        # Number of whole-degree phases in the inclusive range.
        n_steps = end_phase-start_phase+1
        
        for i in range(n_steps):
            
            phase=start_phase+i
        
        # newangle = (angle + np.radians(phase))
        
        # print(np.degrees(angle[50,120]),np.degrees(newangle[60,100]))
        
        
        # [newvel90_270,newvel0_180] = pol2cart(mag,newangle)
            
            
            
            # plt.figure()
            # plt.imshow(rs_vel0_180,vmin=-5,vmax=5)
            # plt.figure()
            # plt.imshow(s_vel0_180,vmin=-5,vmax=5)
            
            
            # Rotate both difference maps; the 90-270 map is rotated 90
            # degrees less than the 0-180 map.
            rvel0_180 = ndimage.rotate(s_vel0_180, phase, mode = 'constant')
            rvel90_270 = ndimage.rotate(s_vel90_270, phase-90, mode = 'constant')
            
            
            # The rotation enlarges the array; find its centre so the
            # central 88x88 can be cut back out.
            coi_x = rvel0_180[0,:].size/2
            coi_y = rvel0_180[:,0].size/2
            
            # print(int(coi_x),int(coi_y))
            
            # Both crops use rvel0_180's centre.
            rrvel0_180=rvel0_180[int(coi_y)-44:int(coi_y)+44,int(coi_x)-44:int(coi_x)+44]
            rrvel90_270=rvel90_270[int(coi_y)-44:int(coi_y)+44,int(coi_x)-44:int(coi_x)+44]
            
            
            
            newvel_phase = phase_shift(rrvel0_180,rrvel90_270,phase)
            
            # Keep the first- and last-phase maps for the left and middle
            # panels.
            if i == 0:
                start_phase_map = newvel_phase
            if i == n_steps-1:
                end_phase_map = newvel_phase

            blank_phase = blank_phase + newvel_phase
            # fig, (ax1, ax2) = plt.subplots(2)
            # ax1.imshow(newvel_phase,vmin=-5,vmax=5)
            # ax2.imshow(blank_phase/(i+1),vmin=-5,vmax=5)
        # Mean over the night's phases (right panel).
        blank_phase = blank_phase/n_steps





    # 1x3 panel figure: first phase | last phase | mean.
    fig, ax = plt.subplots(nrows=1,ncols=3, figsize=(6,3.3))
    
    # Arrow style used to mark the phase direction on the first two panels.
    prop = dict(arrowstyle="-|>,head_width=0.4,head_length=0.8",
            shrinkA=0,shrinkB=0,facecolor='purple', edgecolor='black')  
    # extent maps each 88x88 map onto +/-4.4 (x) by +/-1.1 (y), in arcsec
    # per the axis labels. aspect=4 stretches y so the 8.8 x 2.2 extent is
    # displayed square.
    ax[0].imshow(start_phase_map,vmin=-1.5,vmax=1.5,cmap='seismic_r',origin='lower',extent=[-4.4,+4.4,-1.1,+1.1],aspect=4)
    ax[0].set_xlim(-4,4)
    ax[0].set_ylim(-1.1,1.1)
    # NOTE(review): margins only affect autoscaling. With xlim set explicitly
    # just above, this call appears to have no visible effect.
    ax[0].set_xmargin(20)
    ax[0].set_xlabel('arcsec (x)')
    ax[0].set_ylabel('arcsec (y)')
    
    # Middle panel. Its image handle drives the shared colorbar (all three
    # panels use the same vmin/vmax and cmap).
    colorbarvalues = ax[1].imshow(end_phase_map,vmin=-1.5,vmax=1.5,cmap='seismic_r',origin='lower',extent=[-4.4,+4.4,-1.1,+1.1],aspect=4)

    
    ax[1].set_xlim(-4,4)
    ax[1].set_ylim(-1.1,1.1)
    ax[1].set_xlabel('arcsec (x)')
    # ax[1].set_ylabel('arcsec (y)')
    ax[1].tick_params(axis='y',label1On=False)
    
    # Right panel: the mean over all phases.
    ax[2].imshow(blank_phase,vmin=-1.5,vmax=1.5,cmap='seismic_r',origin='lower',extent=[-4.4,+4.4,-1.1,+1.1],aspect=4)
    ax[2].set_xlim(-4,4)
    ax[2].set_ylim(-1.1,1.1)
    ax[2].set_xlabel('arcsec (x)')
    # ax[2].set_ylabel('arcsec (y)')
    ax[2].tick_params(axis='y',label1On=False)
    
    # ax[0].arrow(0,0,2*np.sin(np.radians(start_phase)),0.5*np.cos(np.radians(start_phase)),head_width=0.3)
    # ax[1].arrow(0,0,2*np.sin(np.radians(end_phase)),0.5*np.cos(np.radians(end_phase)),head_width=0.2,head_length=0.3)
    # Arrows from the origin along the start/end phase angle, measured from
    # +y towards +x. The y length (0.75) is a quarter of the x length (3),
    # matching the aspect=4 stretch.
    ax[0].annotate("", xy=(3*np.sin(np.radians(start_phase)),0.75*np.cos(np.radians(start_phase))), xytext=(0,0), arrowprops=prop)
    ax[1].annotate("", xy=(3*np.sin(np.radians(end_phase)),0.75*np.cos(np.radians(end_phase))), xytext=(0,0), arrowprops=prop)
# plt.annotate("", xy=(.0,.5), xytext=(0,0), arrowprops=prop)


    
    
    plt.subplots_adjust(bottom=0.05, top=0.95)
    
    # Hand-placed axes for a horizontal colorbar, [left, bottom, width,
    # height] in figure fractions.
    cax = plt.axes([0.2, 0.05, 0.6, 0.035])
    # cax.set_ymargin(0.0001)
    # NOTE(review): aspect is presumably ignored when an explicit cax is
    # supplied -- confirm.
    plt.colorbar(colorbarvalues,cax=cax,label='Velocity km/s',aspect=20,orientation='horizontal')
    
        
    plt.savefig('asd_phase_calc'+mirror_str+'.pdf', dpi=300, bbox_inches='tight', facecolor='white', pad_inches=0) 
    plt.show()


# Quick-look diagnostics: report the most recent phase range, then plot a
# cut along row 65 of the summed bin0 + bin180 intensity, limited to the
# range [0, 0.004].
print(start_phase, end_phase)
ii = np.clip(int0 + int180, 0.00, 0.004)
plt.figure()
plt.plot(ii[65, :])



# plt.xlim([75,125])
# plt.plot([119,119],[0,0.004])
# plt.plot([85,85],[0,0.004])


# losvelseeing=np.load('losvelseeing.npy')

# plt.plot(vel0[65,:]-vel0[65,100])
# plt.ylim([-5,5])
# plt.plot(np.arange(losvelseeing[90,:].size)+10,losvelseeing[90,:])
# print("auroral peaks are 34 pixels apart, by eye")

#%% Generate Bin 4 temperature


